'''Train CIFAR10 with PyTorch.'''
import inspect
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
#from utils import progress_bar


def setup_seed(seed):
    """Seed every RNG in play (python, numpy, torch CPU and all GPUs)
    and force deterministic cuDNN kernels for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU RNG
    torch.cuda.manual_seed_all(seed)  # every visible GPU (no-op on CPU-only)
    torch.backends.cudnn.deterministic = True  # same results across CPU/GPU runs



def gen_data(root, bs=128):
    """Build the three CIFAR-10 DataLoaders.

    The 50k official training images are split 45k/5k into train/validation.
    Random crop + horizontal flip augmentation is applied to the training
    transform.
    NOTE(review): the validation split is taken from the same dataset object
    and therefore also receives the augmenting transform — confirm this is
    intended.

    :param root: dataset root directory (downloaded there if absent)
    :param bs: batch size shared by all three loaders
    :return: (trainloader, validloader, testloader)
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    loader_opts = dict(batch_size=bs, num_workers=4, pin_memory=True)

    full_train = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=train_tf)
    trainloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(full_train, range(0, 45000)),
        shuffle=True, **loader_opts)
    validloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(full_train, range(45000, 50000)),
        shuffle=False, **loader_opts)

    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=test_tf)
    testloader = torch.utils.data.DataLoader(
        testset, shuffle=False, **loader_opts)

    return trainloader, validloader, testloader


# if device == 'cuda':
#     net = torch.nn.DataParallel(net)
#     cudnn.benchmark = True

best_acc = 0  # best test accuracy; NOTE(review): never updated in this file — checkpointing code appears to have been removed


# Training
def train(loader, model, opt, loss_func, epoch, device):
    """Run one training epoch and return the mean batch loss.

    :param loader: iterable of (inputs, targets) batches
    :param model: network; called as model(inputs, targets) — targets are
        forwarded at train time (NOTE(review): presumably a margin-style head
        needs the labels; confirm against the model definition)
    :param opt: optimizer, stepped once per batch
    :param loss_func: callable loss; must also expose .predict(outputs)
        mapping raw outputs to per-class scores
    :param epoch: epoch index (printed only)
    :param device: device the batches are moved to
    :return: average training loss over the epoch (0.0 for an empty loader)
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        opt.zero_grad()
        outputs = model(inputs, targets)
        loss = loss_func(outputs, targets)
        loss.backward()
        opt.step()

        train_loss += loss.item()
        _, predicted = loss_func.predict(outputs).max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

        # Fix: the original called progress_bar(), whose import is commented
        # out at the top of the file and would raise NameError at runtime.
        print('\r%d/%d | Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (num_batches, len(loader), train_loss / num_batches,
                 100. * correct / total, correct, total), end='')
    print()
    # Guard: the original dereferenced batch_idx after the loop, which raises
    # NameError when the loader is empty.
    return train_loss / num_batches if num_batches else 0.0


def test(loader, model, loss_func, epoch, device):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)

            test_loss += loss.item()
            _, predicted = loss_func.predict(outputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    return test_loss / (batch_idx + 1), acc


def save_arch(model, loss_func, optimizer, f):
    """Serialize the model, loss-function, and optimizer state dicts to f
    under the keys 'model', 'loss_func', and 'opt' (see load_arch)."""
    torch.save(
        {
            'model': model.state_dict(),
            'loss_func': loss_func.state_dict(),
            'opt': optimizer.state_dict(),
        },
        f,
    )


def load_arch(model, loss_func, optimizer, f):
    """Restore model/loss/optimizer state previously written by save_arch."""
    checkpoint = torch.load(f)
    for obj, key in ((model, 'model'),
                     (loss_func, 'loss_func'),
                     (optimizer, 'opt')):
        obj.load_state_dict(checkpoint[key])
